from recbole.utils import InputType
from recbole_cdr.model.crossdomain_recommender import CrossDomainRecommender
from recbole.model.init import xavier_normal_initialization
from recbole.model.loss import BPRLoss
from recbole_cdr.data.dataset import CrossDomainDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple


# --- HELPERS -----------------------------------------------------------------

def truncate_history(seq_ids: torch.Tensor,
                     seq_scores: Optional[torch.Tensor],
                     max_len: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Shorten each row of a batch of sequences to at most *max_len* entries.

    Args:
        seq_ids: Tensor indexed as ``[batch, position, ...]``; dimension 1 is
            the one that gets truncated.
        seq_scores: Optional per-position scores with the same leading two
            dimensions as ``seq_ids``. When given, the *max_len* highest
            scoring positions of every row are kept (in descending score
            order, as returned by ``Tensor.topk``). When ``None``, the last
            *max_len* positions are kept.
        max_len: Maximum number of positions to keep; must be >= 0.

    Returns:
        ``(ids, scores)``. The inputs are returned untouched when the sequence
        is already short enough; ``scores`` is ``None`` whenever
        ``seq_scores`` is ``None``.

    Raises:
        ValueError: If ``max_len`` is negative.

    NOTE(review): keeping the *last* positions only preserves the most recent
    items if padding sits at the front of each row -- confirm against the
    code that builds the history matrix.
    """
    if max_len < 0:
        raise ValueError(f"max_len must be non-negative, got {max_len}")

    seq_len = seq_ids.size(1)
    if seq_len <= max_len:
        return seq_ids, seq_scores

    if seq_scores is None:
        # Slice from an explicit start index: ``[:, -max_len:]`` would keep
        # the *whole* sequence when max_len == 0 because ``-0 == 0``.
        return seq_ids[:, seq_len - max_len:], None

    top_scores, top_idx = seq_scores.topk(max_len, dim=1)
    row_idx = torch.arange(seq_ids.size(0), device=seq_ids.device).unsqueeze(-1)
    # ``top_scores`` already holds the values found at ``top_idx``, so there
    # is no need to gather the scores a second time.
    return seq_ids[row_idx, top_idx], top_scores


# --- IMPROVED BEHAVIOR AGGREGATOR --------------------------------------------

class ImprovedBehaviorAggregator(nn.Module):
    """Blend a user's ID embedding with a pooled summary of their item history.

    The history is pooled with one of three strategies (``'mean'``,
    ``'multi_head_attention'`` or ``'user_attention'``), passed through a
    small MLP, mixed with the ID embedding using ``lambda_a`` and finally
    layer-normalised.

    Padding positions are recognised by their embedding summing to exactly
    zero over the feature dimension, so callers must feed all-zero vectors at
    padded positions.

    Args:
        embedding_dim: Size of the ID / item embeddings.
        aggregator_type: Pooling strategy; validated lazily in ``forward``.
        lambda_a: Weight of the ID embedding in the final mix; the pooled
            history receives ``1 - lambda_a``.
        dropout_rate: Dropout probability used in the MLP and the attention
            branches.
        num_heads: Number of heads for ``'multi_head_attention'``.
    """

    def __init__(self,
                 embedding_dim: int,
                 aggregator_type: str = 'mean',
                 lambda_a: float = 0.5,
                 dropout_rate: float = 0.1,
                 num_heads: int = 4):
        super().__init__()
        self.aggregator = aggregator_type
        self.lambda_a = lambda_a
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads

        # Transformation applied to the pooled history vector.
        self.W_agg = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.LayerNorm(embedding_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(embedding_dim, embedding_dim)
        )

        # Only the sub-module required by the chosen strategy is created.
        if self.aggregator == 'multi_head_attention':
            self.multihead_attn = nn.MultiheadAttention(
                embedding_dim, num_heads, dropout=dropout_rate, batch_first=True
            )
        elif self.aggregator == 'user_attention':
            self.W_att = nn.Sequential(
                nn.Linear(embedding_dim, embedding_dim),
                nn.Tanh(),
                nn.Dropout(dropout_rate)
            )

        self.dropout = nn.Dropout(dropout_rate)
        self.layer_norm = nn.LayerNorm(embedding_dim)

    @staticmethod
    def masked_softmax(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Softmax over the last dimension with ``mask == True`` positions suppressed."""
        x = x.masked_fill(mask, -1e9)
        return F.softmax(x, dim=-1)

    def mean_pooling(self, seq_emb: torch.Tensor) -> torch.Tensor:
        """Average the non-padding vectors of ``seq_emb`` along dimension 1.

        Rows made of padding only yield a zero vector (the epsilon keeps the
        division finite).
        """
        mask = seq_emb.sum(dim=-1) != 0
        mean = seq_emb.sum(dim=1) / (mask.float().sum(dim=-1, keepdim=True) + 1e-12)
        return mean

    def multi_head_attention_pooling(self, id_emb: torch.Tensor, seq_emb: torch.Tensor) -> torch.Tensor:
        """Pool the history with multi-head attention, using the user as query.

        Rows whose history is entirely padding return a zero vector, matching
        what the other pooling strategies produce for such rows.
        """
        query = id_emb.unsqueeze(1)  # one query per user
        key = value = seq_emb

        key_padding_mask = (seq_emb.sum(dim=-1) == 0)

        # If every key of a row is masked, the attention softmax runs over
        # nothing but -inf and returns NaN. Lift the mask for those rows so
        # the call stays finite, then overwrite their output with zeros.
        all_padding = key_padding_mask.all(dim=-1, keepdim=True)
        safe_mask = key_padding_mask & ~all_padding

        attn_output, _ = self.multihead_attn(
            query, key, value,
            key_padding_mask=safe_mask
        )
        return attn_output.squeeze(1).masked_fill(all_padding, 0.0)

    def user_attention_pooling(self, id_emb: torch.Tensor, seq_emb: torch.Tensor) -> torch.Tensor:
        """Pool the history with dot-product attention between user and projected items."""
        key = self.W_att(seq_emb)
        mask = seq_emb.sum(dim=-1) == 0
        attn = torch.bmm(key, id_emb.unsqueeze(-1)).squeeze(-1)
        attn = self.masked_softmax(attn, mask)
        attn = self.dropout(attn)
        # The weights are applied to the raw (un-projected) item embeddings.
        out = torch.bmm(attn.unsqueeze(1), seq_emb).squeeze(1)
        return out

    def forward(self, id_emb: torch.Tensor, seq_emb: torch.Tensor,
                score: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the fused user representation.

        Args:
            id_emb: User ID embeddings, one vector per row.
            seq_emb: History item embeddings, one sequence per row, with
                all-zero vectors at padded positions.
            score: Accepted for interface compatibility; currently unused.

        Raises:
            ValueError: If the configured aggregator type is not supported.
        """
        if self.aggregator == 'mean':
            aggregated = self.mean_pooling(seq_emb)
        elif self.aggregator == 'multi_head_attention':
            aggregated = self.multi_head_attention_pooling(id_emb, seq_emb)
        elif self.aggregator == 'user_attention':
            aggregated = self.user_attention_pooling(id_emb, seq_emb)
        else:
            raise ValueError(f"Invalid aggregator type: {self.aggregator}")

        # Transform the pooled history.
        aggregated = self.W_agg(aggregated)

        # Weighted residual mix with the ID embedding, then layer norm.
        output = self.lambda_a * id_emb + (1 - self.lambda_a) * aggregated
        return self.layer_norm(output)


# --- DOMAIN ALIGNMENT MODULE -------------------------------------------------

class _GradientReversal(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by ``-scale`` backward."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient flows to ``scale``.
        return -ctx.scale * grad_output, None


class DomainAlignmentModule(nn.Module):
    """Project source/target user representations and classify their domain.

    Both inputs go through the same ``alignment_net``. The projected features
    are then fed to ``domain_discriminator`` through a gradient reversal
    layer, so that minimising the domain classification loss trains the
    discriminator to tell the domains apart while pushing the alignment
    network (and everything upstream) towards features it cannot separate.
    Without the reversal both parts would cooperate and the loss would make
    the two domains *more* distinguishable.

    Args:
        embedding_dim: Size of the incoming representations.
        hidden_dim: Width of the first discriminator layer; defaults to
            ``2 * embedding_dim``.
        grl_lambda: Scale applied to the reversed gradient.

    Note:
        ``alignment_net`` contains ``BatchNorm1d``, which rejects batches
        with a single row while the module is in training mode.
    """

    def __init__(self, embedding_dim: int, hidden_dim: Optional[int] = None,
                 grl_lambda: float = 1.0):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = embedding_dim * 2
        self.grl_lambda = grl_lambda

        # Domain discriminator: produces one logit per row.
        self.domain_discriminator = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 2, 1)
        )

        # Feature alignment network shared by both domains.
        self.alignment_net = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.BatchNorm1d(embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim)
        )

    def forward(self, src_repr: torch.Tensor, tgt_repr: torch.Tensor):
        """Return ``(aligned_src, aligned_tgt, src_domain_logit, tgt_domain_logit)``.

        The aligned features are returned without gradient reversal; only the
        path into the discriminator is reversed.
        """
        aligned_src = self.alignment_net(src_repr)
        aligned_tgt = self.alignment_net(tgt_repr)

        # Forward values are unchanged by the reversal layer.
        src_domain_pred = self.domain_discriminator(
            _GradientReversal.apply(aligned_src, self.grl_lambda))
        tgt_domain_pred = self.domain_discriminator(
            _GradientReversal.apply(aligned_tgt, self.grl_lambda))

        return aligned_src, aligned_tgt, src_domain_pred, tgt_domain_pred


# --- IMPROVED UniCDR --------------------------------------------------------

class UniCDR(CrossDomainRecommender):
    """Cross-domain recommender with domain-specific and shared user encoders.

    Every user is represented as the sum of

    * a *domain-specific* vector: the domain's own user embedding fused with
      the user's history embedded in that domain's item table, and
    * a *shared* vector: a shared user embedding fused with the same history
      embedded in a shared item table, then projected by the alignment
      network of ``DomainAlignmentModule``.

    Scores are dot products between that user vector and the domain-specific
    item embedding. Training adds a contrastive term and a domain
    classification term to the main prediction loss.
    """

    input_type = InputType.POINTWISE

    # Share of non-positive samples kept per batch when ``mask_delete`` is on.
    NEG_KEEP_RATIO = 0.3

    def __init__(self, config, dataset: CrossDomainDataset):
        super().__init__(config, dataset)

        # Hyper-parameters
        self.embedding_dim = config['embedding_size']
        self.aggregator_ty = config['aggregator_type']
        self.lambda_a = config['lambda_a']
        self.dropout_rate = config['drop_rate']
        self.lambda_cl = config['lambda_cl']
        self.lambda_domain = config['lambda_domain']
        self.mask_prob = config['mask_prob']
        self.max_his_len = config['max_his_len']
        self.temperature = config['temperature']
        # Cached here so that calculate_loss does not depend on the base
        # class exposing a ``config`` attribute.
        self.mask_delete = bool(config['mask_delete'])

        # NOTE(review): presumably the base class does not always define the
        # label field names; resolve them from the dataset only when they are
        # missing -- confirm against CrossDomainRecommender.
        if not hasattr(self, 'SOURCE_LABEL'):
            self.SOURCE_LABEL = dataset.source_domain_dataset.label_field
        if not hasattr(self, 'TARGET_LABEL'):
            self.TARGET_LABEL = dataset.target_domain_dataset.label_field

        # Loss function.
        # NOTE(review): calculate_loss calls ``loss_fn(prediction, label)``.
        # That matches BCEWithLogitsLoss; BPRLoss presumably expects
        # (positive score, negative score), so 'BPR' looks incompatible with
        # this pointwise model -- confirm before using it.
        self.loss_fn = BPRLoss() if config['loss_type'] == 'BPR' else nn.BCEWithLogitsLoss()

        # Embedding tables; index 0 is the source domain, index 1 the target.
        self.specific_user_emb = nn.ModuleList([
            nn.Embedding(self.total_num_users, self.embedding_dim) for _ in range(2)
        ])
        self.specific_item_emb = nn.ModuleList([
            nn.Embedding(self.total_num_items, self.embedding_dim, padding_idx=0) for _ in range(2)
        ])
        self.shared_user_emb = nn.Embedding(self.total_num_users, self.embedding_dim)

        # Shared item table used to build the cross-domain (shared) view.
        self.shared_item_emb = nn.Embedding(self.total_num_items, self.embedding_dim, padding_idx=0)

        # History aggregators.
        self.specific_agg = nn.ModuleList([
            ImprovedBehaviorAggregator(
                self.embedding_dim, self.aggregator_ty,
                self.lambda_a, self.dropout_rate
            ) for _ in range(2)
        ])
        self.shared_agg = ImprovedBehaviorAggregator(
            self.embedding_dim, self.aggregator_ty,
            self.lambda_a, self.dropout_rate
        )

        # Domain alignment module.
        self.domain_alignment = DomainAlignmentModule(self.embedding_dim)

        # Pre-computed interaction histories, moved to the model device.
        (self.src_hist_id, self.src_hist_val, self.src_hist_len) = dataset.history_item_matrix(domain='source')
        (self.tgt_hist_id, self.tgt_hist_val, self.tgt_hist_len) = dataset.history_item_matrix(domain='target')
        for name in ['src_hist_id', 'src_hist_val', 'src_hist_len',
                     'tgt_hist_id', 'tgt_hist_val', 'tgt_hist_len']:
            setattr(self, name, getattr(self, name).to(self.device))

        self.apply(xavier_normal_initialization)

        # The initialiser above is applied to every sub-module, embedding
        # tables included, so the ``padding_idx`` rows may no longer be zero.
        # The aggregators detect padding through all-zero embeddings, hence
        # the rows are reset explicitly (a no-op if they were left untouched).
        with torch.no_grad():
            for table in (*self.specific_item_emb, self.shared_item_emb):
                table.weight[table.padding_idx].zero_()

    # ------------------------------------------------------------------ #
    # Losses
    # ------------------------------------------------------------------ #

    def contrastive_loss(self, src_repr, tgt_repr, shared_repr):
        """InfoNCE loss between source and target representations.

        Row ``i`` of ``src_repr`` is pulled towards row ``i`` of ``tgt_repr``
        and pushed away from every other row, using temperature-scaled cosine
        similarity.

        NOTE(review): this treats row ``i`` of both inputs as the same user;
        confirm that the data loader pairs source and target rows that way.

        Args:
            src_repr: Source-side representations, one row per sample.
            tgt_repr: Target-side representations with the same shape.
            shared_repr: Accepted for interface compatibility; unused.
        """
        batch_size = src_repr.size(0)

        # Normalising first and using one matmul gives the same pairwise
        # cosine similarities without materialising a B x B x D tensor.
        src_unit = F.normalize(src_repr, dim=-1, eps=1e-8)
        tgt_unit = F.normalize(tgt_repr, dim=-1, eps=1e-8)
        logits = torch.matmul(src_unit, tgt_unit.t()) / self.temperature

        # The matching row (diagonal) is the positive class. Index targets
        # are equivalent to a one-hot identity matrix and are accepted by
        # every PyTorch version.
        targets = torch.arange(batch_size, device=src_repr.device)
        return F.cross_entropy(logits, targets)

    def domain_adversarial_loss(self, aligned_src, aligned_tgt, src_domain_pred, tgt_domain_pred):
        """Binary cross-entropy of the domain logits (source = 0, target = 1).

        ``aligned_src`` and ``aligned_tgt`` are accepted for interface
        compatibility and are not used.
        """
        src_domain_labels = torch.zeros_like(src_domain_pred)
        tgt_domain_labels = torch.ones_like(tgt_domain_pred)

        domain_loss = F.binary_cross_entropy_with_logits(src_domain_pred, src_domain_labels) + \
                      F.binary_cross_entropy_with_logits(tgt_domain_pred, tgt_domain_labels)

        return domain_loss

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #

    def _downsample_negatives(self, user, item, label):
        """Keep all positive rows and a random NEG_KEEP_RATIO share of the rest."""
        is_pos = label > 0
        is_neg = label <= 0
        lucky = torch.rand_like(label) < self.NEG_KEEP_RATIO
        keep = is_pos | (is_neg & lucky)
        return user[keep], item[keep], label[keep]

    def _mask_history(self, hist_id):
        """Randomly replace real history items by the padding id 0."""
        drop = (hist_id != 0) & (torch.rand_like(hist_id, dtype=torch.float) < self.mask_prob)
        return hist_id.masked_fill(drop, 0)

    def _specific_repr(self, user, hist_id, domain_idx):
        """Domain-specific user vector for domain ``domain_idx`` (0 source, 1 target)."""
        u_spec = self.specific_user_emb[domain_idx](user)
        hist_emb = self.specific_item_emb[domain_idx](hist_id)
        return self.specific_agg[domain_idx](u_spec, hist_emb)

    def _shared_repr(self, user, hist_id):
        """Shared user vector built from the shared user and item tables."""
        u_shared = self.shared_user_emb(user)
        hist_shared = self.shared_item_emb(hist_id)
        return self.shared_agg(u_shared, hist_shared)

    # ------------------------------------------------------------------ #
    # Training
    # ------------------------------------------------------------------ #

    def calculate_loss(self, interaction):
        """Compute the training loss for one batch.

        The result is ``main + lambda_cl * contrastive + lambda_domain *
        domain``, where ``main`` sums the prediction losses of both domains.
        """
        # Batch fields.
        src_u = interaction[self.SOURCE_USER_ID]
        src_i = interaction[self.SOURCE_ITEM_ID]
        src_y = interaction[self.SOURCE_LABEL].float()
        tgt_u = interaction[self.TARGET_USER_ID]
        tgt_i = interaction[self.TARGET_ITEM_ID]
        tgt_y = interaction[self.TARGET_LABEL].float()

        # Optional negative down-sampling (negatives are thinned, not removed).
        if self.mask_delete:
            src_u, src_i, src_y = self._downsample_negatives(src_u, src_i, src_y)
            tgt_u, tgt_i, tgt_y = self._downsample_negatives(tgt_u, tgt_i, tgt_y)

        # Both domains must contribute the same number of rows because the
        # contrastive loss compares them row by row.
        bs = min(src_u.size(0), tgt_u.size(0))
        src_u, src_i, src_y = src_u[:bs], src_i[:bs], src_y[:bs]
        tgt_u, tgt_i, tgt_y = tgt_u[:bs], tgt_i[:bs], tgt_y[:bs]

        # Interaction histories of the batch users.
        src_hist_id = self.src_hist_id[src_u]
        tgt_hist_id = self.tgt_hist_id[tgt_u]

        # Random history masking.
        if self.mask_prob > 0:
            src_hist_id = self._mask_history(src_hist_id)
            tgt_hist_id = self._mask_history(tgt_hist_id)

        src_hist_id, _ = truncate_history(src_hist_id, None, self.max_his_len)
        tgt_hist_id, _ = truncate_history(tgt_hist_id, None, self.max_his_len)

        # Candidate item embeddings.
        src_i_spec = self.specific_item_emb[0](src_i)
        tgt_i_spec = self.specific_item_emb[1](tgt_i)

        # Domain-specific user vectors.
        src_spec = self._specific_repr(src_u, src_hist_id, 0)
        tgt_spec = self._specific_repr(tgt_u, tgt_hist_id, 1)

        # Shared user vectors (shared item table).
        src_shared = self._shared_repr(src_u, src_hist_id)
        tgt_shared = self._shared_repr(tgt_u, tgt_hist_id)

        # Domain alignment.
        aligned_src, aligned_tgt, src_domain_pred, tgt_domain_pred = self.domain_alignment(src_shared, tgt_shared)

        # Final user representations.
        src_user_repr = src_spec + aligned_src
        tgt_user_repr = tgt_spec + aligned_tgt

        # Predictions.
        pred_src = (src_user_repr * src_i_spec).sum(-1)
        pred_tgt = (tgt_user_repr * tgt_i_spec).sum(-1)

        # Main loss.
        loss_main = self.loss_fn(pred_src, src_y) + self.loss_fn(pred_tgt, tgt_y)

        # Contrastive loss.
        cl_loss = self.contrastive_loss(src_shared, tgt_shared, aligned_src)

        # Domain classification loss.
        domain_loss = self.domain_adversarial_loss(aligned_src, aligned_tgt, src_domain_pred, tgt_domain_pred)

        total_loss = loss_main + self.lambda_cl * cl_loss + self.lambda_domain * domain_loss

        return total_loss

    # ------------------------------------------------------------------ #
    # Inference
    # ------------------------------------------------------------------ #

    def _forward_user(self, user, domain='target'):
        """Return the user vectors used for scoring in ``domain``.

        Any value other than ``'target'`` selects the source domain.
        """
        hist_id = self.tgt_hist_id[user] if domain == 'target' else self.src_hist_id[user]
        hist_id, _ = truncate_history(hist_id, None, self.max_his_len)

        domain_idx = 1 if domain == 'target' else 0
        spec = self._specific_repr(user, hist_id, domain_idx)
        shared = self._shared_repr(user, hist_id)

        # calculate_loss scores with ``spec + alignment_net(shared)``; the
        # same projection must be applied here, otherwise inference would
        # rank items with a user vector the model was never trained to score.
        # The discriminator is not needed for prediction.
        aligned = self.domain_alignment.alignment_net(shared)
        return spec + aligned

    def predict(self, interaction):
        """Score the (user, item) pairs of the target domain in ``interaction``."""
        user = interaction[self.TARGET_USER_ID]
        item = interaction[self.TARGET_ITEM_ID]
        user_rep = self._forward_user(user, domain='target')
        item_rep = self.specific_item_emb[1](item)
        return (user_rep * item_rep).sum(-1)

    def full_sort_predict(self, interaction):
        """Score every target-domain item for each user, flattened to 1-D."""
        user = interaction[self.TARGET_USER_ID]
        user_rep = self._forward_user(user, domain='target')
        # Only the first ``target_num_items`` rows of the target item table
        # are scored.
        all_items = self.specific_item_emb[1].weight[:self.target_num_items]
        scores = torch.matmul(user_rep, all_items.t())
        return scores.view(-1)